{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "58e80e23-7227-4e5f-83ba-bbfd4fd430ba",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Enable IPython's autoreload extension in mode 2: imported modules are re-imported before\n",
    "# each cell executes, so edits to the local project files used by this notebook take effect\n",
    "# without restarting the kernel.\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aafd3450-fde7-44fe-aef7-e631df94bff2",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from modeling_llama import LlamaForCausalLM\n",
    "from transformers import AutoTokenizer\n",
    "from datasets import load_dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "707ea855-07a2-4544-8edd-bfeea3d4d773",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Constructing the model runs nn.Linear's default weight initialisation for every linear\n",
    "# layer, and the checkpoint then overwrites those weights. Stub the initialiser out while\n",
    "# loading to skip that work, and restore it afterwards so that any Linear layer created later\n",
    "# in this kernel is initialised normally instead of being left with uninitialised memory.\n",
    "_linear_reset_parameters = torch.nn.Linear.reset_parameters\n",
    "torch.nn.Linear.reset_parameters = lambda self: None\n",
    "try:\n",
    "    # Weights are loaded in bfloat16 from a checkpoint directory next to the working directory.\n",
    "    model = LlamaForCausalLM.from_pretrained('../llama-2-13b/', torch_dtype=torch.bfloat16)\n",
    "finally:\n",
    "    # Always undo the monkey-patch, even if loading fails.\n",
    "    torch.nn.Linear.reset_parameters = _linear_reset_parameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "277a49c1-0fd8-4a41-bf62-7acd877c44f6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Move the model to CUDA device 0, then put it in evaluation mode.\n",
    "model = model.to('cuda:0')\n",
    "model = model.eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0d5dc125-c5d0-49ff-b929-d70a5cbca2eb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the tokenizer from the local checkpoint directory.\n",
    "tokenizer = AutoTokenizer.from_pretrained(\n",
    "    '../llama-2-13b/',\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9808aa5d-eaa9-4f26-b147-623bf279e94f",
   "metadata": {},
   "outputs": [],
   "source": [
    "def build_one_shot_prompts(split, article_key, summary_key, num_prompts=4, shot_offset=100):\n",
    "    \"\"\"Build one-shot summarisation prompts from a dataset split.\n",
    "\n",
    "    Prompt ``k`` uses ``split[k + shot_offset]`` as the worked example (its article followed\n",
    "    by its reference summary) and ``split[k]`` as the article left for the model to summarise.\n",
    "\n",
    "    Args:\n",
    "        split: Integer-indexable collection of records; each record maps field names to strings.\n",
    "        article_key: Name of the field holding the source text.\n",
    "        summary_key: Name of the field holding the reference summary.\n",
    "        num_prompts: Number of prompts to build.\n",
    "        shot_offset: Index distance between a target record and its worked example.\n",
    "\n",
    "    Returns:\n",
    "        A new list of ``num_prompts`` prompt strings, each ending with an open 'Summary:' cue.\n",
    "    \"\"\"\n",
    "    built = []\n",
    "    for k in range(num_prompts):\n",
    "        shot = split[k + shot_offset]\n",
    "        target = split[k]\n",
    "        # Newlines inside the reference summary are removed (not replaced by a space) so the\n",
    "        # worked example's summary occupies a single line of the prompt.\n",
    "        demonstration = 'Article: ' + shot[article_key] + '\\nSummary: ' + shot[summary_key].replace('\\n', '')\n",
    "        built.append(demonstration + '\\nArticle: ' + target[article_key] + '\\nSummary:')\n",
    "    return built"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7e3620ee-5cb8-41d3-b183-aaeaad245163",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Load both summarisation datasets and shuffle them with a fixed seed (4242) so the records\n",
    "# picked by index below are the same on every run.\n",
    "xsum = load_dataset('xsum').shuffle(4242)\n",
    "cnn = load_dataset('cnn_dailymail', '3.0.0').shuffle(4242)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e1192abb-d217-4778-b7a9-17c00c4f8671",
   "metadata": {},
   "outputs": [],
   "source": [
    "# CNN/DailyMail prompts: source text is in 'article', reference summary in 'highlights'.\n",
    "cnn_prompts = build_one_shot_prompts(cnn['train'], 'article', 'highlights')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ea3ed14c-e0a7-4c79-81bb-1e97e87e8d6b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# XSum prompts: source text is in 'document', reference summary in 'summary'.\n",
    "xsum_prompts = build_one_shot_prompts(xsum['train'], 'document', 'summary')\n",
    "\n",
    "# Rebuild the combined list from scratch (CNN prompts first, then XSum) rather than appending\n",
    "# to a shared list, so re-running this cell never duplicates prompts.\n",
    "prompts = cnn_prompts + xsum_prompts"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4df0eccd-4f1c-4043-8ed1-af4a2078e89e",
   "metadata": {},
   "outputs": [],
   "source": [
    "from searching import LayerSkippingSearching"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "671854ee-2992-46f1-a43f-477afcd76bcf",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create the search object from the model, tokenizer and evaluation prompts.\n",
    "# The evaluate_config entries are interpreted by the `searching` module.\n",
    "layer_searching = LayerSkippingSearching(\n",
    "    model,\n",
    "    tokenizer,\n",
    "    prompts,\n",
    "    evaluate_config={\n",
    "        \"generate_fn\": \"essg\",\n",
    "        \"max_new_tokens\": 32,\n",
    "    },\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d53d16ef-1225-499d-b8fd-f5e67d2fc1de",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hand-picked configurations passed to probe(). Each tuple holds the two positional\n",
    "# arguments of one call.\n",
    "# NOTE(review): the two lists look like layer indices to skip (attention layers, then MLP\n",
    "# layers) - confirm against LayerSkippingSearching.probe.\n",
    "first_probe = (\n",
    "    [8, 10, 15, 18, 20, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38],\n",
    "    [],\n",
    ")\n",
    "second_probe = (\n",
    "    [3, 5, 6, 8, 10, 11, 14, 15, 18, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 33, 34, 35, 36, 37],\n",
    "    [6, 9, 10, 11, 15, 24, 25, 27, 28, 35],\n",
    ")\n",
    "layer_searching.probe(*first_probe)\n",
    "layer_searching.probe(*second_probe)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "45336d0f-21a0-43d8-8c54-a4a819040b3b",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Run the search. The argument is presumably the iteration budget - confirm against\n",
    "# LayerSkippingSearching.search.\n",
    "search_budget = 1000\n",
    "layer_searching.search(search_budget)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "32f50131-abea-4a35-b9a1-f13a83c73e15",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fetch the search result, keep a handle on it, and show it as the cell output.\n",
    "solution = layer_searching.get_solution()\n",
    "solution"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a11a90a5-3e79-49c1-8ac9-7c5d1b9305be",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "nebula-fav2",
   "language": "python",
   "name": "nebula-fav2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
